use super::{task_from_ctx, Attr, FnOnceFuture, TaskRef};
use crate::Error;
use core::alloc::Layout;
use core::future::Future;
use core::marker::PhantomData;
use core::mem::MaybeUninit;
use core::pin::Pin;
use core::ptr::{self, NonNull};
use core::task::{Context, Poll};

/// Handle to a spawned task: it can await or `join` the task's output, query
/// whether the task has finished, or cancel it.
///
/// Note that the output type is `Result<Future::Output, Error>`; the caller is
/// responsible for handling the error returned when asynchronous scheduling fails.
/// The handle can also be used from synchronous code: call [`JoinHandle::join`]
/// to obtain the result.
pub struct JoinHandle<T> {
    // `None` for a handle that refers to no task (see `JoinHandle::null`).
    task: Option<TaskRef>,
    // Ties the handle to the output type `T`. The raw-pointer form also makes the
    // type `!Send`/`!Sync` by default; `Send` is re-added explicitly below.
    mark: PhantomData<*const T>,
}

// SAFETY: sending the handle moves the right to receive the task's `T` output to
// another thread, hence the `T: Send` bound. NOTE(review): this also assumes every
// `TaskRef` operation used by the handle (output/join/abort/is_finished) is safe to
// perform from a thread other than the one that spawned the task — not visible
// here, confirm.
unsafe impl<T: Send> Send for JoinHandle<T> {}

impl<T> Future for JoinHandle<T> {
    type Output = Result<T, Error>;
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        if let Some(ref task) = self.task {
            let mut output = MaybeUninit::<T>::uninit();
            match unsafe { task.output(ctx.waker(), output.as_mut_ptr().cast::<()>()) } {
                Poll::Ready(Ok(())) => Poll::Ready(Ok(unsafe { output.assume_init_read() })),
                Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
                Poll::Pending => Poll::Pending,
            }
        } else {
            Poll::Ready(Err(Error::default()))
        }
    }
}

impl<T> JoinHandle<T> {
    pub(crate) fn new(task: TaskRef) -> Self {
        Self {
            task: Some(task),
            mark: PhantomData,
        }
    }

    pub(crate) fn null() -> Self {
        Self {
            task: None,
            mark: PhantomData,
        }
    }

    /// 应该只在同步环境中调用此接口.
    pub fn join(&self) -> <Self as Future>::Output {
        if let Some(ref task) = self.task {
            let mut output = MaybeUninit::<T>::uninit();
            unsafe {
                task.join(output.as_mut_ptr().cast::<()>())
                    .map(|_| output.assume_init_read())
            }
        } else {
            Err(Error::default())
        }
    }

    /// 取消任务，如果成功取消则返回TRUE，否则说明异步任务已经执行完毕，取消失败.
    pub fn abort(&self) -> bool {
        if let Some(ref task) = self.task {
            return task.abort();
        }
        true
    }

    /// 如果异步任务已经执行完毕，返回TRUE，否则返回FALSE.
    pub fn is_finished(&self) -> bool {
        if let Some(ref task) = self.task {
            task.is_finished()
        } else {
            true
        }
    }

    /// 返回AbortHandle，可以按需取消任务执行.
    pub fn abort_handle(mut self) -> AbortHandle<T> {
        AbortHandle::<T> {
            task: self.task.take(),
            mark: PhantomData,
        }
    }
}

/// Handle that can only cancel a task or check whether it has finished; it never
/// yields the task's output. Obtained from [`JoinHandle::abort_handle`].
pub struct AbortHandle<T> {
    // `None` when converted from a handle that referred to no task.
    task: Option<TaskRef>,
    // Keeps the output type `T` in the signature; the raw-pointer form also makes
    // the type `!Send`/`!Sync` by default.
    mark: PhantomData<*const T>,
}

// SAFETY: mirrors the `JoinHandle` impl. NOTE(review): assumes `TaskRef::abort` and
// `TaskRef::is_finished` are safe to call from any thread — not visible here, confirm.
unsafe impl<T: Send> Send for AbortHandle<T> {}

impl<T> AbortHandle<T> {
    pub fn abort(&self) -> bool {
        if let Some(ref task) = self.task {
            return task.abort();
        }
        true
    }
    pub fn is_finished(&self) -> bool {
        if let Some(ref task) = self.task {
            task.is_finished()
        } else {
            true
        }
    }
}

/// A set of spawned tasks that can be awaited more efficiently as a group: wait
/// for all of them, or for whichever one finishes next.
///
/// Provides no `join` operation, so it can only be used from async contexts.
#[repr(C)]
pub struct JoinSet<T> {
    // Tasks whose completion has not been observed yet.
    running: JoinList<T>,
    // Finished tasks whose results have not been handed out yet.
    exited: JoinList<T>,
    // Id given to the next spawned task (spawn order, starting at 0).
    jcnt: u32,
}

// SAFETY: the raw `NonNull` list links make `JoinSet` `!Send` by default. The set
// stores finished outputs of type `T` (in each `JoinNode::output`), and moving the
// set to another thread moves those values with it, so `T: Send` is required. This
// matches the bound on `JoinHandle`'s `Send` impl. Without it, e.g. a
// `JoinSet<Rc<_>>` filled through `spawn_local` could be sent across threads.
// NOTE(review): like `JoinHandle`, this assumes `TaskRef` may be used from any thread.
unsafe impl<T: Send> Send for JoinSet<T> {}

impl<T> Default for JoinSet<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> Drop for JoinSet<T> {
    /// Aborts every still-running task (releasing its node), then drains the
    /// exited list so the node and result of every finished task are released too.
    fn drop(&mut self) {
        self.abort();
        Iter::new(self).for_each(drop);
    }
}

impl<T> JoinSet<T> {
    /// Creates an empty set.
    pub const fn new() -> Self {
        Self {
            jcnt: 0,
            running: JoinList::new(),
            exited: JoinList::new(),
        }
    }

    /// Spawns `future` with the default [`Attr`] and adds it to the set.
    ///
    /// # Errors
    /// Returns `Err` if spawning produced a null handle.
    pub fn spawn<F>(&mut self, future: F) -> Result<(), Error>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        self.spawn_with(future, &Attr::default())
    }

    /// Spawns `future` with a copy of `attr` that also reserves room for the set's
    /// bookkeeping node, and adds the task to the set.
    ///
    /// # Errors
    /// Returns `Err` if spawning produced a null handle.
    pub fn spawn_with<F>(&mut self, future: F, attr: &Attr) -> Result<(), Error>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let attr = Self::join_attr(attr);
        let handle = super::spawn_with(future, &attr);
        self.push_handle(handle)
    }

    /// Like [`JoinSet::spawn`], but for futures that are not `Send`.
    ///
    /// # Errors
    /// Returns `Err` if spawning produced a null handle.
    pub fn spawn_local<F>(&mut self, future: F) -> Result<(), Error>
    where
        F: Future<Output = T> + 'static,
        T: 'static,
    {
        self.spawn_local_with(future, &Attr::default())
    }

    /// Like [`JoinSet::spawn_with`], but the task is created through
    /// `super::spawn_local_with`, so `future` need not be `Send`.
    ///
    /// # Errors
    /// Returns `Err` if spawning produced a null handle.
    pub fn spawn_local_with<F>(&mut self, future: F, attr: &Attr) -> Result<(), Error>
    where
        F: Future<Output = T> + 'static,
        T: 'static,
    {
        let attr = Self::join_attr(attr);
        let handle = super::spawn_local_with(future, &attr);
        self.push_handle(handle)
    }

    /// Runs the closure `f` as a task (wrapped in `FnOnceFuture`) and adds it to the set.
    pub fn spawn_fn<F>(&mut self, f: F) -> Result<(), Error>
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        self.spawn(FnOnceFuture::new(f))
    }

    /// Like [`JoinSet::spawn_fn`], using `attr` for the task.
    pub fn spawn_fn_with<F>(&mut self, f: F, attr: &Attr) -> Result<(), Error>
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        self.spawn_with(FnOnceFuture::new(f), attr)
    }

    /// Runs the closure `f` as a local task and adds it to the set.
    // NOTE(review): unlike `spawn_local`, the `_local` closure variants still require
    // `F: Send` and `T: Send`; the bounds look copied from `spawn_fn`. They could
    // probably be relaxed if `FnOnceFuture` does not need them — confirm.
    pub fn spawn_fn_local<F>(&mut self, f: F) -> Result<(), Error>
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        self.spawn_local(FnOnceFuture::new(f))
    }

    /// Like [`JoinSet::spawn_fn_local`], using `attr` for the task.
    pub fn spawn_fn_local_with<F>(&mut self, f: F, attr: &Attr) -> Result<(), Error>
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        self.spawn_local_with(FnOnceFuture::new(f), attr)
    }

    /// Aborts every task still in the running list and releases its node.
    ///
    /// Tasks already moved to the exited list are left untouched, so their results
    /// can still be retrieved.
    pub fn abort(&mut self) {
        while let Some(node) = self.running.pop() {
            node.abort();
            // SAFETY: `node` was initialised by `push_handle` and has just been
            // unlinked, so nothing else refers to it; it is dropped exactly once here.
            unsafe {
                ptr::drop_in_place(node);
            }
        }
    }

    /// Waits until every task has finished, then returns an iterator over their results.
    ///
    /// Each item is `(id, result)`, where `id` is the task's spawn index starting
    /// from 0. Items come in the order completions were detected, not necessarily
    /// spawn order. This is meant to be cheaper than awaiting each handle in turn:
    ///
    /// ```text
    /// for h in handles {
    ///     h.await;
    /// }
    /// ```
    ///
    /// While tasks are pending, the waiting task's wake expectation is set to the
    /// number of pending tasks (at most 128 per round). Presumably it is then only
    /// re-polled after that many of them have finished.
    pub async fn wait_all(&mut self) -> impl Iterator<Item = (usize, Result<T, Error>)> + '_ {
        WaitAll::new(self).await;
        Iter::new(self)
    }

    /// Waits until any task finishes and returns its `(id, result)`.
    ///
    /// `id` is the task's spawn index starting from 0. Results are returned in the
    /// order completions are detected rather than spawn order. Returns `None` once
    /// the set holds no more tasks.
    pub async fn wait_any(&mut self) -> Option<(usize, Result<T, Error>)> {
        WaitAny::new(self).await
    }

    /// Checks running tasks for completion.
    ///
    /// Finished tasks are moved to the exited list. Pending ones (which were handed
    /// `ctx`'s waker) are put back at the front of the running list. Scanning stops
    /// after `MAX_PENDING` pending tasks. Returns how many pending tasks were seen.
    fn try_read_output(&mut self, ctx: &mut Context<'_>) -> u8 {
        // Upper bound on pending tasks examined per call; keeps the count within `u8`.
        const MAX_PENDING: u8 = 128;
        let mut pending = 0;
        let mut waiting = JoinList::new();
        while let Some(node) = self.running.pop() {
            if node.read_output(ctx) {
                self.exited.push(node);
            } else {
                waiting.push(node);
                pending += 1;
                if pending == MAX_PENDING {
                    break;
                }
            }
        }
        // Pending nodes go back in front of the ones not examined this round.
        waiting.move_head(&mut self.running);
        pending
    }

    /// Returns a copy of `attr` that also requests a tail region laid out for one
    /// `JoinNode<T>`; `push_handle` later writes the set's node there.
    fn join_attr(attr: &Attr) -> Attr {
        let layout = Layout::new::<JoinNode<T>>();
        let mut attr = attr.clone();
        // NOTE(review): the result of `Attr::tail` is discarded. If reserving the tail
        // can fail, `push_handle` would write a `JoinNode` into space that was never
        // reserved. Confirm it cannot fail here, or propagate the error.
        let _ = attr.tail(layout);
        attr
    }

    /// Moves the task out of `handle` into a `JoinNode` stored in the task's tail
    /// area, and appends that node to the running list.
    ///
    /// # Errors
    /// Returns `Err(Error::default())` if `handle` refers to no task.
    fn push_handle(&mut self, mut handle: JoinHandle<T>) -> Result<(), Error> {
        let Some(task) = handle.task.take() else {
            return Err(Error::default());
        };
        let layout = Layout::new::<JoinNode<T>>();
        // SAFETY: the task was spawned with an `Attr` from `join_attr`, which requested
        // a tail region with exactly this layout for one `JoinNode<T>`.
        let tail = unsafe { task.tail(layout) };
        // SAFETY: `tail` points to that (uninitialised) region, used by nothing else.
        let slot = unsafe { &mut *tail.cast_mut().cast::<MaybeUninit<JoinNode<T>>>() };
        let node = slot.write(JoinNode::new(Some(task), self.jcnt));
        self.running.push(node);
        self.jcnt += 1;
        Ok(())
    }
}

/// Future used by [`JoinSet::wait_all`]; it completes once no task remains in
/// the set's running list.
struct WaitAll<'a, T> {
    set: &'a mut JoinSet<T>,
}

impl<'a, T> WaitAll<'a, T> {
    /// Wraps a mutable borrow of `set`.
    fn new(set: &'a mut JoinSet<T>) -> Self {
        WaitAll { set }
    }
}

impl<T> Future for WaitAll<'_, T> {
    type Output = ();

    /// Moves newly finished tasks to the exited list.
    ///
    /// While some are still pending, the current task's wake expectation is set to
    /// the number of pending tasks and `Pending` is returned. Otherwise the
    /// expectation is reset to 1 and the future completes.
    /// NOTE(review): if this future is dropped while pending, the expectation is not
    /// reset here.
    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let pending = self.set.try_read_output(ctx);
        // SAFETY: assumes `ctx` was created by this runtime for the task currently
        // being polled, so `task_from_ctx` yields a valid, exclusive task pointer.
        let current = unsafe { task_from_ctx(ctx).as_mut() };
        let (expect, outcome) = if pending == 0 {
            (1, Poll::Ready(()))
        } else {
            (pending, Poll::Pending)
        };
        current.status.set_wake_expect(expect);
        outcome
    }
}

/// Future used by [`JoinSet::wait_any`]; it completes with the next finished
/// task's result, or `None` when the set is empty.
struct WaitAny<'a, T> {
    set: &'a mut JoinSet<T>,
}

impl<'a, T> WaitAny<'a, T> {
    /// Wraps a mutable borrow of `set`.
    fn new(set: &'a mut JoinSet<T>) -> Self {
        WaitAny { set }
    }
}

impl<T> Future for WaitAny<'_, T> {
    type Output = Option<(usize, Result<T, Error>)>;

    /// Collects newly finished tasks, then hands out the first exited result.
    ///
    /// Returns `Ready(None)` when nothing has exited and nothing is running, and
    /// `Pending` otherwise.
    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let set = &mut *self.set;
        // Pending tasks are handed `ctx`'s waker during this scan.
        let _ = set.try_read_output(ctx);
        match set.exited.pop() {
            Some(node) => {
                let item = node.get();
                // SAFETY: `node` was initialised by `push_handle`, has just been
                // unlinked, and is dropped exactly once here.
                unsafe { ptr::drop_in_place(node) };
                Poll::Ready(Some(item))
            }
            None if set.running.empty() => Poll::Ready(None),
            None => Poll::Pending,
        }
    }
}

/// Draining iterator over the results of a set's exited tasks.
struct Iter<'a, T> {
    set: &'a mut JoinSet<T>,
}

impl<'a, T> Iter<'a, T> {
    /// Wraps a mutable borrow of `set`.
    fn new(set: &'a mut JoinSet<T>) -> Self {
        Iter { set }
    }
}

impl<T> Iterator for Iter<'_, T> {
    type Item = (usize, Result<T, Error>);

    /// Takes the next `(id, result)` out of the exited list and releases its node.
    /// Returns `None` once the exited list is empty.
    fn next(&mut self) -> Option<Self::Item> {
        let node = self.set.exited.pop()?;
        let item = node.get();
        // SAFETY: `node` was initialised by `push_handle`, has just been unlinked,
        // and is dropped exactly once here.
        unsafe { ptr::drop_in_place(node) };
        Some(item)
    }
}

/// Per-task bookkeeping record linked into a `JoinSet`'s lists.
///
/// The record is written into the tail area of its own task's allocation (see
/// `JoinSet::push_handle`) rather than into a separate allocation, and is released
/// with `ptr::drop_in_place`.
#[repr(C)]
struct JoinNode<T> {
    /// Task result once `read_output` has observed completion; taken by `get`.
    output: Option<Result<T, Error>>,
    /// Intrusive link to the next node in the list this node belongs to.
    next: Option<NonNull<JoinNode<T>>>,
    /// Spawn-order id assigned by the owning `JoinSet`.
    id: u32,
    /// Reference to the task whose allocation holds this node.
    ///
    /// Declared LAST on purpose, because fields drop in declaration order. If this
    /// is the final `TaskRef`, dropping it may release the task's memory, which
    /// contains this very node (assumes `TaskRef`'s drop frees the task on the last
    /// reference — not visible here). Dropping it last ensures no field is touched
    /// afterwards.
    task: Option<TaskRef>,
}

impl<T> JoinNode<T> {
    /// Creates an unlinked node for `task` with spawn id `id` and no result yet.
    fn new(task: Option<TaskRef>, id: u32) -> Self {
        Self {
            id,
            task,
            output: None,
            next: None,
        }
    }

    /// Takes the stored result together with the node's id.
    ///
    /// # Panics
    /// Panics if `read_output` has not yet returned `true` for this node.
    fn get(&mut self) -> (usize, Result<T, Error>) {
        let output = self
            .output
            .take()
            .expect("JoinNode::get called before the task output was read");
        (self.id as usize, output)
    }

    /// Requests cancellation of the task; the abort result is ignored.
    fn abort(&mut self) {
        self.task.as_ref().unwrap().abort();
    }

    /// Polls the task for its output, passing it `ctx`'s waker.
    ///
    /// Returns `true` once the task is done; its result (value or error) is then
    /// stored in `self.output`. Returns `false` while it is still pending.
    fn read_output(&mut self, ctx: &mut Context<'_>) -> bool {
        let mut output = MaybeUninit::<T>::uninit();
        let task = self.task.as_ref().unwrap();
        // SAFETY: `output` is a properly aligned, writable location for one `T`.
        match unsafe { task.output(ctx.waker(), output.as_mut_ptr().cast::<()>()) } {
            Poll::Ready(Ok(())) => {
                // SAFETY: `Ready(Ok(()))` means the task wrote its output into `output`.
                self.output = Some(Ok(unsafe { output.assume_init_read() }));
                true
            }
            Poll::Ready(Err(err)) => {
                self.output = Some(Err(err));
                true
            }
            Poll::Pending => false,
        }
    }
}

/// Intrusive singly linked FIFO list of `JoinNode`s.
///
/// The list does not own its nodes. Each node lives in memory owned elsewhere
/// (in this module, the tail area of a task — see `JoinSet::push_handle`) and is
/// linked only through its `next` pointer.
///
/// Invariants kept by every method:
/// - `first.is_none() == last.is_none()`;
/// - when non-empty, the node at `last` has `next == None`.
struct JoinList<T> {
    first: Option<NonNull<JoinNode<T>>>,
    last: Option<NonNull<JoinNode<T>>>,
}

impl<T> JoinList<T> {
    /// Creates an empty list.
    const fn new() -> Self {
        Self {
            first: None,
            last: None,
        }
    }

    /// Appends `node` at the tail.
    ///
    /// The caller must keep `node` alive and not move it while it is linked.
    fn push(&mut self, node: &mut JoinNode<T>) {
        // A node popped from another list may carry a stale link; the tail must
        // always terminate the chain.
        node.next = None;
        let node = NonNull::from(node);
        if let Some(mut last) = self.last {
            // SAFETY: `last` is a live node currently linked in this list.
            let last = unsafe { last.as_mut() };
            last.next = Some(node);
        } else {
            self.first = Some(node);
        }
        self.last = Some(node);
    }

    /// Unlinks and returns the head node, or `None` if the list is empty.
    fn pop(&mut self) -> Option<&mut JoinNode<T>> {
        let mut first = self.first?;
        // SAFETY: `first` is a live node currently linked in this list.
        let node = unsafe { first.as_mut() };
        if self.last == self.first {
            self.first = None;
            self.last = None;
        } else {
            self.first = node.next;
        }
        // The node is detached now; do not leave a dangling link behind.
        node.next = None;
        Some(node)
    }

    /// Moves every node of `self` to the front of `dst`, leaving `self` empty.
    fn move_head(&mut self, dst: &mut Self) {
        let Some(mut last) = self.last.take() else {
            return;
        };
        // SAFETY: `last` is a live node that was the tail of `self`.
        unsafe { last.as_mut().next = dst.first };
        // If `dst` was empty, our tail becomes its tail. Previously `dst.last` stayed
        // `None` here, so the next `push` overwrote `first` and lost every moved node.
        if dst.last.is_none() {
            dst.last = Some(last);
        }
        dst.first = self.first.take();
    }

    /// Returns `true` if the list has no nodes.
    fn empty(&self) -> bool {
        self.first.is_none()
    }
}
